package org.tlang.ast.list.expr;

import org.tlang.ast.AST;
import org.tlang.ast.ASTLeaf;
import org.tlang.ast.ASTList;
import org.tlang.ast.leaf.Name;
import org.tlang.ast.list.array.ArrayAccessor;
import org.tlang.ast.list.func.FunctionInvoker;
import org.tlang.context.Location;
import org.tlang.context.SymbolTable;
import org.tlang.context.TypeTable;
import org.tlang.context.ValueTable;
import org.tlang.exception.EvalException;
import org.tlang.exception.TypeException;
import org.tlang.metaclass.TGetter;
import org.tlang.metaclass.TObject;
import org.tlang.metaclass.TSetter;
import org.tlang.type.BoolType;
import org.tlang.type.ClassType;
import org.tlang.type.FuncType;
import org.tlang.type.Type;

import java.util.List;

/**
 * Binary expression node of the form {@code left operator right}.
 *
 * <p>The three children are, in order, the left operand, the operator token and the right
 * operand. Assignment ({@code =}) and member access ({@code .}) get dedicated handling; every
 * other operator is computed from the evaluated operand values.
 */
public class BinaryExpr extends ASTList {
    /** Operator text of a member-access expression. */
    private static final String DOT = ".";

    public BinaryExpr(List<AST> children) {
        super(children);
    }

    /** Returns the left operand (child 0). */
    public AST left() {
        return child(0);
    }

    /** Returns the operator text, read from the token of child 1. */
    public String operator() {
        return ((ASTLeaf) child(1)).token().getText();
    }

    /** Returns the right operand (child 2). */
    public AST right() {
        return child(2);
    }

    /**
     * Resolves this dot expression while it is itself the receiver of an enclosing dot expression.
     *
     * @param callerSymbolTable symbol table of the scope in which the whole expression appears
     * @param callerTypeTable   type table of the scope in which the whole expression appears
     * @return the field tables of the class type produced by this expression's right-hand member
     * @throws EvalException if a member cannot be found or does not yield a class type
     */
    private TablePair lookupAsDotLeft(SymbolTable callerSymbolTable, TypeTable callerTypeTable) {
        TablePair receiver = lookupReceiver(callerSymbolTable, callerTypeTable);
        return lookupAsDotItem(right(), callerSymbolTable, callerTypeTable,
                receiver.symbolTable, receiver.typeTable);
    }

    /**
     * Resolves the left operand of a dot expression and returns the field tables of its type.
     *
     * <p>A nested dot expression is resolved recursively; any other left operand is looked up in
     * the caller's own tables.
     */
    private TablePair lookupReceiver(SymbolTable callerSymbolTable, TypeTable callerTypeTable) {
        if (leftIsDotExpr()) {
            return ((BinaryExpr) left()).lookupAsDotLeft(callerSymbolTable, callerTypeTable);
        }
        return lookupAsDotItem(left(), callerSymbolTable, callerTypeTable,
                callerSymbolTable, callerTypeTable);
    }

    /**
     * Resolves one segment of a dot chain that must itself produce an object.
     *
     * <p>A {@link Name} must be typed as a {@link ClassType}; a {@link FunctionInvoker} must be
     * typed as a {@link FuncType} whose return type is a {@link ClassType}. The segment's own
     * lookup is performed only after those checks succeed.
     *
     * @param ast               the segment to resolve
     * @param callerSymbolTable symbol table of the calling scope (passed on for method invocations)
     * @param callerTypeTable   type table of the calling scope (passed on for method invocations)
     * @param objSymbolTable    symbol table in which the segment's name is searched
     * @param objTypeTable      type table matching {@code objSymbolTable}
     * @return the field tables of the class type the segment produces
     * @throws EvalException "symbol not found" if the name is unknown, "bad type" if the type
     *                       requirements above are not met, "syntax error" for any other node kind
     */
    private TablePair lookupAsDotItem(AST ast, SymbolTable callerSymbolTable, TypeTable callerTypeTable,
                                      SymbolTable objSymbolTable, TypeTable objTypeTable) {
        if (ast instanceof Name) {
            Location location = objSymbolTable.where(((Name) ast).name());
            ClassType itemType = requireClassType(memberType(location, objTypeTable, ast), ast);

            ast.lookup(objSymbolTable, objTypeTable);

            return new TablePair(itemType.fieldSymbolTable(), itemType.fieldTypeTable());
        }

        if (ast instanceof FunctionInvoker) {
            FunctionInvoker invoker = (FunctionInvoker) ast;
            Location location = objSymbolTable.where(invoker.name());
            Type declared = memberType(location, objTypeTable, ast);
            if (!(declared instanceof FuncType)) {
                throw new EvalException("bad type", ast);
            }
            // The object we continue with is whatever the method returns.
            ClassType itemType = requireClassType(((FuncType) declared).returnType(), ast);

            invoker.lookupAsMethodInvoker(callerSymbolTable, callerTypeTable,
                    objSymbolTable, objTypeTable);

            return new TablePair(itemType.fieldSymbolTable(), itemType.fieldTypeTable());
        }

        throw new EvalException("syntax error", ast);
    }

    /**
     * Returns the type recorded at {@code location}.
     *
     * @throws EvalException "symbol not found" if {@code location} is {@code null}
     */
    private static Type memberType(Location location, TypeTable objTypeTable, AST ast) {
        if (location == null) {
            throw new EvalException("symbol not found", ast);
        }
        return objTypeTable.get(location.nest(), location.index());
    }

    /**
     * Casts {@code type} to {@link ClassType}.
     *
     * @throws EvalException "bad type" if {@code type} is not a {@link ClassType}
     */
    private static ClassType requireClassType(Type type, AST ast) {
        if (!(type instanceof ClassType)) {
            throw new EvalException("bad type", ast);
        }
        return (ClassType) type;
    }

    /**
     * Performs symbol lookup for this expression.
     *
     * <p>For a dot expression the left side is resolved to an object's field tables and the right
     * side is looked up inside those tables. All other operators use the inherited lookup.
     */
    @Override
    public void lookup(SymbolTable symbolTable, TypeTable typeTable) {
        if (!DOT.equals(operator())) {
            super.lookup(symbolTable, typeTable);
            return;
        }

        TablePair receiver = lookupReceiver(symbolTable, typeTable);
        lookupRight(symbolTable, typeTable, receiver.symbolTable, receiver.typeTable);
    }

    /**
     * Looks up the final (right-most) member of a dot expression inside the receiver's tables.
     *
     * <p>Unlike {@link #lookupAsDotItem}, the member is not required to have a class type.
     *
     * @throws EvalException "syntax error" if the right operand is neither a name nor a call
     */
    private void lookupRight(SymbolTable callerSymbolTable, TypeTable callerTypeTable,
                             SymbolTable objSymbolTable, TypeTable objTypeTable) {
        AST right = right();
        if (right instanceof Name) {
            right.lookup(objSymbolTable, objTypeTable);
        } else if (right instanceof FunctionInvoker) {
            ((FunctionInvoker) right).lookupAsMethodInvoker(callerSymbolTable, callerTypeTable,
                    objSymbolTable, objTypeTable);
        } else {
            // Report the node that is actually unsupported: the right operand.
            throw new EvalException("syntax error", right);
        }
    }

    /**
     * Evaluates this expression.
     *
     * <ul>
     *   <li>{@code =}: evaluates the right operand, unwraps a {@link TGetter}, and assigns.</li>
     *   <li>{@code .}: evaluates the left side to an object and evaluates the right operand
     *       against it; the result is returned without unwrapping.</li>
     *   <li>otherwise: evaluates left then right, unwraps both, and applies the operator.
     *       Both operands are always evaluated, so {@code &&} and {@code ||} do not
     *       short-circuit.</li>
     * </ul>
     *
     * @throws EvalException on an unsupported operand type, operator or assignment target
     */
    @Override
    public Object eval(ValueTable valueTable) {
        switch (operator()) {
            case "=":
                return computeAssign(valueTable, unwrap(right().eval(valueTable)));

            case DOT:
                ValueTable receiver = evalReceiver(valueTable);
                return computeDotItem(right(), valueTable, receiver);

            default:
                Object leftValue = left().eval(valueTable);
                Object rightValue = right().eval(valueTable);
                leftValue = unwrap(leftValue);
                rightValue = unwrap(rightValue);
                return computeOp(leftValue, operator(), rightValue);
        }
    }

    /** Returns {@code value.get()} when {@code value} is a {@link TGetter}, else {@code value}. */
    private static Object unwrap(Object value) {
        if (value instanceof TGetter) {
            return ((TGetter) value).get();
        }
        return value;
    }

    /**
     * Evaluates one segment of a dot chain.
     *
     * @param ast              a {@link Name} (evaluated in {@code objValueTable}) or a
     *                         {@link FunctionInvoker} (invoked as a method on {@code objValueTable})
     * @param callerValueTable value table of the calling scope
     * @param objValueTable    value table of the object the segment belongs to
     * @throws EvalException "syntax error" for any other node kind
     */
    private Object computeDotItem(AST ast, ValueTable callerValueTable, ValueTable objValueTable) {
        if (ast instanceof Name) {
            return ast.eval(objValueTable);
        }
        if (ast instanceof FunctionInvoker) {
            return ((FunctionInvoker) ast).evalAsMethodInvoker(callerValueTable, objValueTable);
        }
        throw new EvalException("syntax error", ast);
    }

    /**
     * Evaluates the left operand of a dot expression to the object it denotes.
     *
     * @return the object's value table
     * @throws EvalException "bad type" (located at the left operand) if the value is not a
     *                       {@link TObject}
     */
    private ValueTable evalReceiver(ValueTable valueTable) {
        if (leftIsDotExpr()) {
            return ((BinaryExpr) left()).computeDotLeft(valueTable);
        }

        Object leftValue = unwrap(computeDotItem(left(), valueTable, valueTable));
        if (leftValue instanceof TObject) {
            // NOTE(review): relies on every TObject being a ValueTable - confirm in TObject.
            return (ValueTable) leftValue;
        }
        throw new EvalException("bad type", left());
    }

    /**
     * Evaluates this dot expression while it is itself the receiver of an enclosing dot expression.
     *
     * @return the value table of the object produced by this expression's right-hand member
     * @throws EvalException "bad type" if either side does not evaluate to a {@link TObject}
     */
    private ValueTable computeDotLeft(ValueTable valueTable) {
        ValueTable receiver = evalReceiver(valueTable);

        Object rightValue = unwrap(computeDotItem(right(), valueTable, receiver));
        if (rightValue instanceof TObject) {
            return (ValueTable) rightValue;
        }
        throw new EvalException("bad type", right());
    }

    /** Returns whether the left operand is itself a dot expression. */
    private boolean leftIsDotExpr() {
        return left() instanceof BinaryExpr && DOT.equals(((BinaryExpr) left()).operator());
    }

    /**
     * Assignment: stores {@code rightValue} into the target denoted by the left operand.
     *
     * <p>An {@link ArrayAccessor} is used as the setter directly, without being evaluated; any
     * other left operand is evaluated and must yield a {@link TSetter}.
     *
     * @return the assigned value
     * @throws EvalException "bad assignment" if the left operand does not yield a setter
     */
    private Object computeAssign(ValueTable valueTable, Object rightValue) {
        AST left = left();
        if (left instanceof ArrayAccessor) {
            ((TSetter) left).set(valueTable, rightValue);
            return rightValue;
        }

        Object target = left.eval(valueTable);
        if (target instanceof TSetter) {
            ((TSetter) target).set(valueTable, rightValue);
            return rightValue;
        }

        throw new EvalException("bad assignment", this);
    }

    /**
     * Applies {@code op} to two already-unwrapped operand values.
     *
     * <p>Operands that are both {@code Long}, both {@code Double} or both {@code Boolean} are
     * dispatched to the typed helpers. Any other combination supports {@code +} (string
     * concatenation of both operands) and null-safe {@code ==} / {@code !=} via
     * {@link Object#equals(Object)}.
     *
     * @throws EvalException "bad type" for any other operator on such operands
     */
    private Object computeOp(Object left, String op, Object right) {
        if (left instanceof Long && right instanceof Long) {
            return computeInteger((Long) left, op, (Long) right);
        }
        if (left instanceof Double && right instanceof Double) {
            return computeFloat((Double) left, op, (Double) right);
        }
        if (left instanceof Boolean && right instanceof Boolean) {
            return computeBoolean((Boolean) left, op, (Boolean) right);
        }

        switch (op) {
            case "+":
                return String.valueOf(left) + right;
            case "==":
                return left == null ? right == null : left.equals(right);
            case "!=":
                // typeCheck already types "!=" as bool for any operands, so honour it here too.
                return left == null ? right != null : !left.equals(right);
            default:
                throw new EvalException("bad type", this);
        }
    }

    /**
     * Integer arithmetic and comparison on {@code long} values.
     *
     * <p>{@code /} and {@code %} with a zero divisor propagate Java's {@link ArithmeticException}.
     *
     * @throws EvalException "bad operator" for an operator not listed below
     */
    private Object computeInteger(Long left, String op, Long right) {
        switch (op) {
            case "+":
                return left + right;
            case "-":
                return left - right;
            case "*":
                return left * right;
            case "/":
                return left / right;
            case "%":
                return left % right;
            case "==":
                return left.equals(right);
            case "!=":
                return !left.equals(right);
            case ">":
                return left > right;
            case "<":
                return left < right;
            case ">=":
                return left >= right;
            case "<=":
                return left <= right;
            default:
                throw new EvalException("bad operator", this);
        }
    }

    /**
     * Floating-point arithmetic and comparison on {@code double} values.
     *
     * <p>Comparisons go through {@link Double#compare(double, double)}, so they follow its total
     * ordering (NaN equals NaN; {@code 0.0} and {@code -0.0} are distinct).
     *
     * @throws EvalException "bad operator" for an operator not listed below
     */
    private Object computeFloat(Double left, String op, Double right) {
        switch (op) {
            case "+":
                return left + right;
            case "-":
                return left - right;
            case "*":
                return left * right;
            case "/":
                return left / right;
            case "%":
                return left % right;
            case "==":
                return Double.compare(left, right) == 0;
            case "!=":
                return Double.compare(left, right) != 0;
            case ">":
                return Double.compare(left, right) > 0;
            case "<":
                return Double.compare(left, right) < 0;
            case ">=":
                return Double.compare(left, right) >= 0;
            case "<=":
                return Double.compare(left, right) <= 0;
            default:
                throw new EvalException("bad operator", this);
        }
    }

    /**
     * Boolean logic and equality.
     *
     * @throws EvalException "bad operator" for an operator not listed below
     */
    private Object computeBoolean(Boolean left, String op, Boolean right) {
        switch (op) {
            case "&&":
                return left && right;
            case "||":
                return left || right;
            case "==":
                // Compare by value; "==" on boxed Booleans would compare references.
                return left.equals(right);
            case "!=":
                return !left.equals(right);
            default:
                throw new EvalException("bad operator", this);
        }
    }

    /**
     * Type-checks both operands and calls {@code assertSubTypeOf} on the right operand's type
     * against the left operand's type.
     *
     * @return {@link BoolType#BOOL_TYPE} for the comparison operators
     *         ({@code == != > < >= <=}); the left operand's type for every other operator
     * @throws TypeException if an operand fails to type-check or the subtype assertion fails
     */
    @Override
    public Type typeCheck(TypeTable typeTable) throws TypeException {
        Type leftType = left().typeCheck(typeTable);
        Type rightType = right().typeCheck(typeTable);
        rightType.assertSubTypeOf(leftType, typeTable, this);
        switch (operator()) {
            case "==":
            case "!=":
            case ">":
            case "<":
            case ">=":
            case "<=":
                return BoolType.BOOL_TYPE;
            default:
                return leftType;
        }
    }

    /** Immutable pairing of a symbol table with its matching type table. */
    private static class TablePair {
        private final SymbolTable symbolTable;
        private final TypeTable typeTable;

        public TablePair(SymbolTable symbolTable, TypeTable typeTable) {
            this.symbolTable = symbolTable;
            this.typeTable = typeTable;
        }
    }
}
